{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<center>\n",
    "<img src=\"../../img/ods_stickers.jpg\">\n",
    "## Открытый курс по машинному обучению\n",
    "<center>Автор материала: Ефремова Дина (@ldinka)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <center>Исследование возможностей BigARTM</center>\n",
    "\n",
    "## <center>Тематическое моделирование с помощью BigARTM</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Интро"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BigARTM — библиотека, предназначенная для тематической категоризации текстов; делает разбиение на темы без «учителя».\n",
    "\n",
    "Я собираюсь использовать эту библиотеку для собственных нужд в будущем, но так как она не предназначена для обучения с учителем, решила, что для начала ее стоит протестировать на какой-нибудь уже размеченной выборке. Для этих целей был использован датасет \"20 news groups\".\n",
    "\n",
    "Идея экперимента такова:\n",
    "- делим выборку на обучающую и тестовую;\n",
    "- обучаем модель на обучающей выборке;\n",
    "- «подгоняем» выделенные темы под действительные;\n",
    "- смотрим, насколько хорошо прошло разбиение;\n",
    "- тестируем модель на тестовой выборке."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Поехали!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Внимание!** Данный проект был реализован с помощью Python 3.6 и BigARTM 0.9.0. Методы, рассмотренные здесь, могут отличаться от методов в других версиях библиотеки."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../../img/bigartm_logo.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### <font color=\"lightgrey\">Не</font>множко теории"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "У нас есть словарь терминов $W = \\{w \\in W\\}$, который представляет из себя мешок слов, биграмм или n-грамм;\n",
    "\n",
    "Есть коллекция документов $D = \\{d \\in D\\}$, где $d \\subset W$;\n",
    "\n",
    "Есть известное множество тем $T = \\{t \\in T\\}$;\n",
    "\n",
    "$n_{dw}$ — сколько раз термин $w$ встретился в документе $d$;\n",
    "\n",
    "$n_{d}$ — длина документа $d$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Мы считаем, что существует матрица $\\Phi$ распределения терминов $w$ в темах $t$: (фи) $\\Phi = (\\phi_{wt})$\n",
    "\n",
    "и матрица распределения тем $t$ в документах $d$: (тета) $\\Theta = (\\theta_{td})$,\n",
    "\n",
    "переумножение которых дает нам тематическую модель, или, другими словами, представление наблюдаемого условного распределения $p(w|d)$ терминов $w$ в документах $d$ коллекции $D$:\n",
    "\n",
    "<center>$\\large p(w|d) = \\Phi \\Theta$</center>\n",
    "\n",
    "<center>$$\\large p(w|d) = \\sum_{t \\in T} \\phi_{wt} \\theta_{td}$$</center>\n",
    "\n",
    "где $\\phi_{wt} = p(w|t)$ — вероятности терминов $w$ в каждой теме $t$\n",
    "\n",
    "и $\\theta_{td} = p(t|d)$ — вероятности тем $t$ в каждом документе $d$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../../img/phi_theta.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Нам известны наблюдаемые частоты терминов в документах, это:\n",
    "\n",
    "<center>$ \\large \\hat{p}(w|d) = \\frac {n_{dw}} {n_{d}} $</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Таким образом, наша задача тематического моделирования становится задачей стохастического матричного разложения матрицы $\\hat{p}(w|d)$ на стохастические матрицы $\\Phi$ и $\\Theta$.\n",
    "\n",
    "Напомню, что матрица является стохастической, если каждый ее столбец представляет дискретное распределение вероятностей, сумма значений каждого столбца равна 1."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Воспользовавшись принципом максимального правдоподобия, т. е. максимизируя логарифм правдоподобия, мы получим:\n",
    "\n",
    "<center>$\n",
    "\\begin{cases}\n",
    "\\sum_{d \\in D} \\sum_{w \\in d} n_{dw} \\ln \\sum_{t \\in T} \\phi_{wt} \\theta_{td} \\rightarrow \\max\\limits_{\\Phi,\\Theta};\\\\\n",
    "\\sum_{w \\in W} \\phi_{wt} = 1, \\qquad \\phi_{wt}\\geq0;\\\\\n",
    "\\sum_{t \\in T} \\theta_{td} = 1, \\quad\\quad\\;\\; \\theta_{td}\\geq0.\n",
    "\\end{cases}\n",
    "$</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Чтобы из множества решений выбрать наиболее подходящее, введем критерий регуляризации $R(\\Phi, \\Theta)$:\n",
    "\n",
    "<center>$\n",
    "\\begin{cases}\n",
    "\\sum_{d \\in D} \\sum_{w \\in d} n_{dw} \\ln \\sum_{t \\in T} \\phi_{wt} \\theta_{td} + R(\\Phi, \\Theta) \\rightarrow \\max\\limits_{\\Phi,\\Theta};\\\\\n",
    "\\sum_{w \\in W} \\phi_{wt} = 1, \\qquad \\phi_{wt}\\geq0;\\\\\n",
    "\\sum_{t \\in T} \\theta_{td} = 1, \\quad\\quad\\;\\; \\theta_{td}\\geq0.\n",
    "\\end{cases}\n",
    "$</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Два наиболее известных частных случая этой системы уравнений:\n",
    "- **PLSA**, вероятностный латентный семантический анализ, когда $R(\\Phi, \\Theta) = 0$\n",
    "- **LDA**, латентное размещение Дирихле:\n",
    "$$R(\\Phi, \\Theta) = \\sum_{t,w} (\\beta_{w} - 1) \\ln \\phi_{wt} + \\sum_{d,t} (\\alpha_{t} - 1) \\ln \\theta_{td} $$\n",
    "где $\\beta_{w} > 0$, $\\alpha_{t} > 0$ — параметры регуляризатора."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Однако оказывается запас неединственности решения настолько большой, что на модель можно накладывать сразу несколько ограничений, такой подход называется **ARTM**, или аддитивной регуляризацией тематических моделей:\n",
    "\n",
    "<center>$\n",
    "\\begin{cases}\n",
    "\\sum_{d,w} n_{dw} \\ln \\sum_{t} \\phi_{wt} \\theta_{td} + \\sum_{i=1}^k \\tau_{i} R_{i}(\\Phi, \\Theta) \\rightarrow \\max\\limits_{\\Phi,\\Theta};\\\\\n",
    "\\sum_{w \\in W} \\phi_{wt} = 1, \\qquad \\phi_{wt}\\geq0;\\\\\n",
    "\\sum_{t \\in T} \\theta_{td} = 1, \\quad\\quad\\;\\; \\theta_{td}\\geq0.\n",
    "\\end{cases}\n",
    "$</center>\n",
    "\n",
    "где $\\tau_{i}$ — коэффициенты регуляризации."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Теперь давайте познакомимся с библиотекой BigARTM и разберем еще некоторые аспекты тематического моделирования на ходу."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Если Вас очень сильно заинтересовала теоретическая часть категоризации текстов и тематического моделирования, рекомендую посмотреть видеолекции из курса Яндекса на Coursera «Поиск структуры в данных» четвертой недели: <a href=\"https://www.coursera.org/learn/unsupervised-learning/home/week/4\">Тематическое моделирование</a>."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### BigARTM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Установка"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Естественно, для начала работы с библиотекой ее надо установить. Вот несколько видео, которые рассказывают, как это сделать в зависимости от вашей операционной системы:\n",
    "- <a href=\"https://www.coursera.org/learn/unsupervised-learning/lecture/qmsFm/ustanovka-bigartm-v-windows\">Установка BigARTM в Windows</a>\n",
    "- <a href=\"https://www.coursera.org/learn/unsupervised-learning/lecture/zPyO0/ustanovka-bigartm-v-linux-mint\">Установка BigARTM в Linux</a>\n",
    "- <a href=\"https://www.coursera.org/learn/unsupervised-learning/lecture/nuIhL/ustanovka-bigartm-v-mac-os-x\">Установка BigARTM в Mac OS X</a>\n",
    "\n",
    "Либо можно воспользоваться инструкцией с официального сайта, которая, скорее всего, будет гораздо актуальнее: <a href=\"https://bigartm.readthedocs.io/en/stable/installation/index.html\">здесь</a>. Там же указано, как можно установить BigARTM в качестве <a href=\"https://bigartm.readthedocs.io/en/stable/installation/docker.html\">Docker-контейнера</a>."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Использование BigARTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import artm\n",
    "import re\n",
    "import numpy as np\n",
    "import seaborn as sns; sns.set()\n",
    "\n",
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "from sklearn.preprocessing import normalize\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "artm.version()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Скачаем датасет ***the 20 news groups*** с заранее известным количеством категорий новостей:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import fetch_20newsgroups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "newsgroups = fetch_20newsgroups('../../data/news_data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "newsgroups['target_names']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Приведем данные к формату *Vowpal Wabbit*. Так как BigARTM не рассчитан на обучение с учителем, то мы поступим следующим образом:\n",
    "- обучим модель на всем корпусе текстов;\n",
    "- выделим ключевые слова тем и по ним определим, к какой теме они скорее всего относятся;\n",
    "- сравним наши полученные результаты разбиения с истинными значенями."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TEXT_FIELD = \"text\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_vw_format(document, label=None):\n",
    "    return str(label or '0') + ' |' + TEXT_FIELD + ' ' + ' '.join(re.findall('\\w{3,}', document.lower())) + '\\n'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_documents = newsgroups['data']\n",
    "all_targets = newsgroups['target']\n",
    "len(newsgroups['target'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_documents, test_documents, train_labels, test_labels = \\\n",
    "    train_test_split(all_documents, all_targets, random_state=7)\n",
    "\n",
    "with open('../../data/news_data/20news_train_mult.vw', 'w') as vw_train_data:\n",
    "    for text, target in zip(train_documents, train_labels):\n",
    "        vw_train_data.write(to_vw_format(text, target))\n",
    "with open('../../data/news_data/20news_test_mult.vw', 'w') as vw_test_data:\n",
    "    for text in test_documents:\n",
    "        vw_test_data.write(to_vw_format(text))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Загрузим данные в необходимый для BigARTM формат:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_vectorizer = artm.BatchVectorizer(data_path=\"../../data/news_data/20news_train_mult.vw\",\n",
    "                                        data_format=\"vowpal_wabbit\",\n",
    "                                        target_folder=\"news_batches\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Данные в BigARTM загружаются порционно, укажем в \n",
    "- *data_path* путь к обучающей выборке,\n",
    "- *data_format* — формат наших данных, может быть:\n",
    "    * *bow_n_wd* — это вектор $n_{wd}$ в виду массива *numpy.ndarray*, также необходимо передать соответствующий словарь терминов, где ключ — это индекс вектора *numpy.ndarray* $n_{wd}$, а значение — соответствующий токен.\n",
    "    ```python\n",
    "    batch_vectorizer = artm.BatchVectorizer(data_format='bow_n_wd',\n",
    "                                              n_wd=n_wd,\n",
    "                                              vocabulary=vocabulary)\n",
    "    ```\n",
    "    * *vowpal_wabbit* — формат Vowpal Wabbit;\n",
    "    * *bow_uci* — UCI формат (например, с *vocab.my_collection.txt* и *docword.my_collection.txt* файлами):\n",
    "    ```python\n",
    "    batch_vectorizer = artm.BatchVectorizer(data_path='',\n",
    "                                              data_format='bow_uci',\n",
    "                                              collection_name='my_collection',\n",
    "                                              target_folder='my_collection_batches')\n",
    "    ```\n",
    "    * *batches* — данные, уже сконверченные в батчи с помощью BigARTM;\n",
    "- *target_folder* — путь для сохранения батчей.\n",
    "\n",
    "Пока это все параметры, что нам нужны для загрузки наших данных.\n",
    "\n",
    "После того, как BigARTM создал батчи из данных, можно использовать их для загрузки:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_vectorizer = artm.BatchVectorizer(data_path=\"news_batches\", data_format='batches')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Инициируем модель с известным нам количеством тем. Количество тем — это гиперпараметр, поэтому если он заранее нам неизвестен, то его необходимо настраивать, т. е. брать такое количество тем, при котором разбиение кажется наиболее удачным.\n",
    "\n",
    "**Важно!** У нас 20 предметных тем, однако некоторые из них довольно узкоспециализированны и смежны, как например 'comp.os.ms-windows.misc' и 'comp.windows.x', или 'comp.sys.ibm.pc.hardware' и 'comp.sys.mac.hardware', тогда как другие размыты и всеобъемлющи: talk.politics.misc' и 'talk.religion.misc'.\n",
    "\n",
    "Скорее всего, нам не удастся в чистом виде выделить все 20 тем — некоторые из них окажутся слитными, а другие наоборот раздробятся на более мелкие. Поэтому мы попробуем построить 40 «предметных» тем и одну фоновую. Чем больше вы будем строить категорий, тем лучше мы сможем подстроиться под данные, однако это довольно трудоемкое занятие сидеть потом и распределять в получившиеся темы по реальным категориям (<strike>я правда очень-очень задолбалась!</strike>).\n",
    "\n",
    "Зачем нужны фоновые темы? Дело в том, что наличие общей лексики в темах приводит к плохой ее интерпретируемости. Выделив общую лексику в отдельную тему, мы сильно снизим ее количество в предметных темах, таким образом оставив там лексическое ядро, т. е. ключевые слова, которые данную тему характеризуют. Также этим преобразованием мы снизим коррелированность тем, они станут более независимыми и различимыми."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "T = 41\n",
    "model_artm = artm.ARTM(num_topics=T,\n",
    "                       topic_names=[str(i) for i in range(T)],\n",
    "                       class_ids={TEXT_FIELD:1}, \n",
    "                       num_document_passes=1,\n",
    "                       reuse_theta=True,\n",
    "                       cache_theta=True,\n",
    "                       seed=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Передаем в модель следующие параметры:\n",
    "- *num_topics* — количество тем;\n",
    "- *topic_names* — названия тем;\n",
    "- *class_ids* — название модальности и ее вес. Дело в том, что кроме самих текстов, в данных может содержаться такая информация, как автор, изображения, ссылки на другие документы и т. д., по которым также можно обучать модель;\n",
    "- *num_document_passes* — количество проходов при обучении модели;\n",
    "- *reuse_theta* — переиспользовать ли матрицу $\\Theta$ с предыдущей итерации;\n",
    "- *cache_theta* — сохранить ли матрицу $\\Theta$ в модели, чтобы в дальнейшем ее использовать.\n",
    "\n",
    "Далее необходимо создать словарь; передадим ему какое-нибудь название, которое будем использовать в будущем для работы с этим словарем."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DICTIONARY_NAME = 'dictionary'\n",
    "\n",
    "dictionary = artm.Dictionary(DICTIONARY_NAME)\n",
    "dictionary.gather(batch_vectorizer.data_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Инициализируем модель с тем именем словаря, что мы передали выше, можно зафиксировать *random seed* для вопроизводимости результатов:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(1)\n",
    "model_artm.initialize(DICTIONARY_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Добавим к модели несколько метрик:\n",
    "- перплексию (*PerplexityScore*), чтобы индентифицировать сходимость модели\n",
    "    * Перплексия — это известная в вычислительной лингвистике мера качества модели языка. Можно сказать, что это мера неопределенности или различности слов в тексте.\n",
    "- специальный *score* ключевых слов (*TopTokensScore*), чтобы в дальнейшем мы могли идентифицировать по ним наши тематики;\n",
    "- разреженность матрицы $\\Phi$ (*SparsityPhiScore*);\n",
    "- разреженность матрицы $\\Theta$ (*SparsityThetaScore*)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_artm.scores.add(artm.PerplexityScore(name='perplexity_score',\n",
    "                                           dictionary=DICTIONARY_NAME))\n",
    "model_artm.scores.add(artm.SparsityPhiScore(name='sparsity_phi_score', class_id=\"text\"))\n",
    "model_artm.scores.add(artm.SparsityThetaScore(name='sparsity_theta_score'))\n",
    "model_artm.scores.add(artm.TopTokensScore(name=\"top_words\", num_tokens=15, class_id=TEXT_FIELD))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Следующая операция *fit_offline* займет некоторое время, мы будем обучать модель в режиме *offline* в 40 проходов. Количество проходов влияет на сходимость модели: чем их больше, тем лучше сходится модель."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "model_artm.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=40)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Построим график сходимости модели и увидим, что модель сходится довольно быстро:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(model_artm.score_tracker[\"perplexity_score\"].value);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Выведем значения разреженности матриц:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Phi', model_artm.score_tracker[\"sparsity_phi_score\"].last_value)\n",
    "print('Theta', model_artm.score_tracker[\"sparsity_theta_score\"].last_value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "После того, как модель сошлась, добавим к ней регуляризаторы. Для начала сглаживающий регуляризатор — это *SmoothSparsePhiRegularizer* с большим положительным коэффициентом $\\tau$, который нужно применить только к фоновой теме, чтобы выделить в нее как можно больше общей лексики. Пусть тема с последним индексом будет фоновой, передадим в *topic_names* этот индекс:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_artm.regularizers.add(artm.SmoothSparsePhiRegularizer(name='SparsePhi', \n",
    "                                                            tau=1e5, \n",
    "                                                            dictionary=dictionary, \n",
    "                                                            class_ids=TEXT_FIELD,\n",
    "                                                            topic_names=str(T-1)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Дообучим модель, сделав 20 проходов по ней с новым регуляризатором:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "model_artm.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Выведем значения разреженности матриц, заметим, что значение для $\\Theta$ немного увеличилось:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Phi', model_artm.score_tracker[\"sparsity_phi_score\"].last_value)\n",
    "print('Theta', model_artm.score_tracker[\"sparsity_theta_score\"].last_value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Теперь добавим к модели разреживающий регуляризатор, это тот же *SmoothSparsePhiRegularizer* резуляризатор, только с отрицательным значением $\\tau$ и примененный ко всем предметным темам:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_artm.regularizers.add(artm.SmoothSparsePhiRegularizer(name='SparsePhi2', \n",
    "                                                            tau=-5e5, \n",
    "                                                            dictionary=dictionary, \n",
    "                                                            class_ids=TEXT_FIELD,\n",
    "                                                            topic_names=[str(i) for i in range(T-1)]),\n",
    "                                                            overwrite=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "model_artm.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Видим, что значения разреженности увеличились еще больше:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(model_artm.score_tracker[\"sparsity_phi_score\"].last_value)\n",
    "print(model_artm.score_tracker[\"sparsity_theta_score\"].last_value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Посмотрим, сколько категорий-строк матрицы $\\Theta$ после регуляризации осталось, т. е. не занулилось/выродилось. И это одна категория:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(model_artm.score_tracker[\"top_words\"].last_tokens.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Теперь выведем ключевые слова тем, чтобы определить, каким образом прошло разбиение, и сделать соответствие с нашим начальным списком тем:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for topic_name in model_artm.score_tracker[\"top_words\"].last_tokens.keys():\n",
    "    tokens = model_artm.score_tracker[\"top_words\"].last_tokens\n",
    "    res_str = topic_name + ': ' + ', '.join(tokens[topic_name])\n",
    "    print(res_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Далее мы будем подгонять разбиение под действительные темы с помощью *confusion matrix*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_dict = {\n",
    "    'alt.atheism': 0,\n",
    "    'comp.graphics': 1,\n",
    "    'comp.os.ms-windows.misc': 2,\n",
    "    'comp.sys.ibm.pc.hardware': 3,\n",
    "    'comp.sys.mac.hardware': 4,\n",
    "    'comp.windows.x': 5,\n",
    "    'misc.forsale': 6,\n",
    "    'rec.autos': 7,\n",
    "    'rec.motorcycles': 8,\n",
    "    'rec.sport.baseball': 9,\n",
    "    'rec.sport.hockey': 10,\n",
    "    'sci.crypt': 11,\n",
    "    'sci.electronics': 12,\n",
    "    'sci.med': 13,\n",
    "    'sci.space': 14,\n",
    "    'soc.religion.christian': 15,\n",
    "    'talk.politics.guns': 16,\n",
    "    'talk.politics.mideast': 17,\n",
    "    'talk.politics.misc': 18,\n",
    "    'talk.religion.misc': 19\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mixed = [\n",
    "    'comp.sys.ibm.pc.hardware',\n",
    "    'talk.politics.mideast',\n",
    "    'sci.electronics',\n",
    "    'rec.sport.hockey',\n",
    "\n",
    "    'sci.med',\n",
    "    'rec.motorcycles',\n",
    "    'comp.graphics',\n",
    "    'rec.sport.hockey',\n",
    "\n",
    "    'talk.politics.mideast',\n",
    "    'talk.religion.misc',\n",
    "    'rec.autos',\n",
    "    'comp.graphics',\n",
    "\n",
    "    'sci.space',\n",
    "    'soc.religion.christian',\n",
    "    'comp.os.ms-windows.misc',\n",
    "    'sci.crypt',\n",
    "\n",
    "    'comp.windows.x',\n",
    "    'misc.forsale',\n",
    "    'sci.space',\n",
    "    'sci.crypt',\n",
    "\n",
    "    'talk.religion.misc',\n",
    "    'alt.atheism',\n",
    "    'comp.os.ms-windows.misc',\n",
    "    'alt.atheism',\n",
    "    \n",
    "    'sci.med',\n",
    "    'comp.os.ms-windows.misc',\n",
    "    'soc.religion.christian',\n",
    "    'talk.politics.guns',\n",
    "\n",
    "    'rec.autos',\n",
    "    'rec.autos',\n",
    "    'talk.politics.mideast',\n",
    "    'rec.sport.baseball',\n",
    "\n",
    "    'talk.religion.misc',\n",
    "    'talk.politics.misc',\n",
    "    'rec.sport.hockey',\n",
    "    'comp.sys.mac.hardware',\n",
    "\n",
    "    'misc.forsale',\n",
    "    'sci.space',\n",
    "    'talk.politics.guns',\n",
    "    'rec.autos',\n",
    "    \n",
    "    '-'\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Построим небольшой отчет о правильности нашего разбиения:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "theta_train = model_artm.get_theta()\n",
    "model_labels = []\n",
    "keys = np.sort([int(i) for i in theta_train.keys()])\n",
    "for i in keys:\n",
    "    max_val = 0\n",
    "    max_idx = 0\n",
    "    for j in theta_train[i].keys():\n",
    "        if j == str(T-1):\n",
    "            continue\n",
    "        if theta_train[i][j] > max_val:\n",
    "            max_val = theta_train[i][j]\n",
    "            max_idx = j\n",
    "    topic = mixed[int(max_idx)]\n",
    "    if topic == '-':\n",
    "        print(i, '-')\n",
    "    label = target_dict[topic]\n",
    "    model_labels.append(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(classification_report(train_labels, model_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(classification_report(train_labels, model_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mat = confusion_matrix(train_labels, model_labels)\n",
    "sns.heatmap(mat.T, annot=True, fmt='d', cbar=False)\n",
    "plt.xlabel('True label')\n",
    "plt.ylabel('Predicted label');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy_score(train_labels, model_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Нам удалось добиться 80% *accuracy*. По матрице ответов мы видим, что для модели темы *comp.sys.ibm.pc.hardware* и *comp.sys.mac.hardware* практически не различимы (<strike>честно говоря, для меня тоже</strike>), в остальном все более или менее прилично.\n",
    "\n",
    "Проверим модель на тестовой выборке:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_vectorizer_test = artm.BatchVectorizer(data_path=\"../../data/news_data/20news_test_mult.vw\",\n",
    "                                             data_format=\"vowpal_wabbit\",\n",
    "                                             target_folder=\"news_batches_test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "theta_test = model_artm.transform(batch_vectorizer_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_score = []\n",
    "for i in range(len(theta_test.keys())):\n",
    "    max_val = 0\n",
    "    max_idx = 0\n",
    "    for j in theta_test[i].keys():\n",
    "        if j == str(T-1):\n",
    "            continue\n",
    "        if theta_test[i][j] > max_val:\n",
    "            max_val = theta_test[i][j]\n",
    "            max_idx = j\n",
    "    topic = mixed[int(max_idx)]\n",
    "    label = target_dict[topic]\n",
    "    test_score.append(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(classification_report(test_labels, test_score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mat = confusion_matrix(test_labels, test_score)\n",
    "sns.heatmap(mat.T, annot=True, fmt='d', cbar=False)\n",
    "plt.xlabel('True label')\n",
    "plt.ylabel('Predicted label');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy_score(test_labels, test_score)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Итого почти 77%, незначительно хуже, чем на обучающей."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Вывод:** безумно много времени пришлось потратить на подгонку категорий к реальным темам, но в итоге я осталась довольна результатом. Такие смежные темы, как *alt.atheism*/*soc.religion.christian*/*talk.religion.misc* или *talk.politics.guns*/*talk.politics.mideast*/*talk.politics.misc* разделились вполне неплохо. Думаю, что я все-таки попробую использовать BigARTM в будущем для своих <strike>корыстных</strike> целей."
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
